{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 5.12 Densely Connected Networks (DenseNet)\n",
    "\n",
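    "The cross-layer connections in ResNet inspired a number of follow-up designs. In this section we introduce one of them: the densely connected network (DenseNet) [1]. Its main difference from ResNet is shown in Figure 5.10.\n",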
    "\n",
    "<div align=center>\n",
    "<img width=\"400\" src=\"../img/chapter05/5.12_densenet.svg\"/>\n",
    "</div>\n",
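    "<div align=center>Figure 5.10 The main difference between ResNet (left) and DenseNet (right) in cross-layer connections: addition versus concatenation</div>\n",
    "\n",
    "In Figure 5.10, some adjacent operations are abstracted as module $A$ and module $B$. The main difference from ResNet is that in DenseNet the output of module $B$ is not added to the output of module $A$, as it is in ResNet, but concatenated with it along the channel dimension. As a result, the output of module $A$ is passed directly to the layers after module $B$. In this design, module $A$ is connected directly to every layer that follows module $B$, which is why the architecture is said to be densely connected.\n",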
    "\n",
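    "The main building blocks of DenseNet are dense blocks and transition layers. The former defines how inputs and outputs are concatenated, and the latter controls the number of channels so that it does not grow too large.\n",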
    "\n",
    "\n",
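    "## 5.12.1 Dense Blocks\n",
    "\n",
    "DenseNet uses the improved batch normalization, activation, and convolution structure from ResNet. We implement it in the `BottleNeck` layer below: a $1\\times 1$ convolution with `4 * growth_rate` output channels, followed by a $3\\times 3$ convolution with `growth_rate` output channels (each preceded by batch normalization and ReLU), and then dropout. In the forward computation, the input and output of each such block are concatenated along the channel dimension."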
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "class BottleNeck(tf.keras.layers.Layer):\n",
    "    # Pre-activation bottleneck: BN-ReLU-Conv(1x1) -> BN-ReLU-Conv(3x3) -> Dropout\n",
    "    def __init__(self, growth_rate, drop_rate):\n",
    "        super(BottleNeck, self).__init__()\n",
    "        self.bn1 = tf.keras.layers.BatchNormalization()\n",
    "        # 1x1 convolution reduces the input to 4 * growth_rate channels before the 3x3 convolution\n",
    "        self.conv1 = tf.keras.layers.Conv2D(filters=4 * growth_rate,\n",
    "                                            kernel_size=(1, 1),\n",
    "                                            strides=1,\n",
    "                                            padding=\"same\")\n",
    "        self.bn2 = tf.keras.layers.BatchNormalization()\n",
    "        # 3x3 convolution producing growth_rate new feature channels\n",
    "        self.conv2 = tf.keras.layers.Conv2D(filters=growth_rate,\n",
    "                                            kernel_size=(3, 3),\n",
    "                                            strides=1,\n",
    "                                            padding=\"same\")\n",
    "        self.dropout = tf.keras.layers.Dropout(rate=drop_rate)\n",
    "\n",
    "        self.listLayers = [self.bn1,\n",
    "                           tf.keras.layers.Activation(\"relu\"),\n",
    "                           self.conv1,\n",
    "                           self.bn2,\n",
    "                           tf.keras.layers.Activation(\"relu\"),\n",
    "                           self.conv2,\n",
    "                           self.dropout]\n",
    "\n",
    "    def call(self, x):\n",
    "        y = x\n",
    "        # Apply the sub-layers in order by iterating over the list itself\n",
    "        for layer in self.listLayers:\n",
    "            y = layer(y)\n",
    "        # Concatenate the input and the new features along the channel axis\n",
    "        y = tf.keras.layers.concatenate([x, y], axis=-1)\n",
    "        return y"
   ]
  },
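  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the difference in Figure 5.10 concrete, the following minimal sketch applies both kinds of cross-layer connection to random tensors. ResNet-style addition keeps the number of channels unchanged, while DenseNet-style concatenation adds the channel counts together. The tensor shapes and the names `a`, `b`, `added`, and `concatenated` are arbitrary choices for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "# Hypothetical outputs of module A and module B: (batch, height, width, channels)\n",
    "a = tf.random.uniform((2, 8, 8, 16))\n",
    "b = tf.random.uniform((2, 8, 8, 16))\n",
    "\n",
    "# ResNet: element-wise addition, the channel count stays 16\n",
    "added = tf.keras.layers.add([a, b])\n",
    "# DenseNet: concatenation along the channel axis, the channel count becomes 16 + 16\n",
    "concatenated = tf.keras.layers.concatenate([a, b], axis=-1)\n",
    "\n",
    "print(added.shape)         # (2, 8, 8, 16)\n",
    "print(concatenated.shape)  # (2, 8, 8, 32)"
   ]
  },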
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A dense block consists of several `BottleNeck` layers, each using the same number of output channels (`growth_rate`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseBlock(tf.keras.layers.Layer):\n",
    "    def __init__(self, num_layers, growth_rate, drop_rate=0.5):\n",
    "        super(DenseBlock, self).__init__()\n",
    "        self.num_layers = num_layers\n",
    "        self.growth_rate = growth_rate\n",
    "        self.drop_rate = drop_rate\n",
    "        # One BottleNeck per layer; each one adds growth_rate channels\n",
    "        self.listLayers = []\n",
    "        for _ in range(num_layers):\n",
    "            self.listLayers.append(BottleNeck(growth_rate=self.growth_rate, drop_rate=self.drop_rate))\n",
    "\n",
    "    def call(self, x):\n",
    "        # Each BottleNeck concatenates its output onto its own input\n",
    "        for layer in self.listLayers:\n",
    "            x = layer(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
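    "In the following example, we define a dense block with 2 convolution blocks, each with 10 output channels. Given an input with 3 channels, we obtain an output with $3+2\\times 10=23$ channels. The number of channels of the convolution blocks controls how fast the number of output channels grows relative to the number of input channels, so it is also called the growth rate."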
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(4, 8, 8, 23)\n"
     ]
    }
   ],
   "source": [
    "blk = DenseBlock(2, 10)\n",
    "X = tf.random.uniform((4, 8, 8, 3))\n",
    "Y = blk(X)\n",
    "print(Y.shape)"
   ]
  },
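  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The same arithmetic holds in general. Suppose a dense block with $L$ convolution blocks and growth rate $k$ receives $c_0$ channels. Then the $i$-th convolution block ($i = 0, \\ldots, L-1$) sees $c_0 + ik$ input channels, and the block outputs $c_0 + Lk$ channels. The pure-Python sketch below spells this out. The function names are hypothetical helpers and are not part of the model code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def bottleneck_input_channels(in_channels, num_layers, growth_rate):\n",
    "    # Channels seen by the i-th BottleNeck (i = 0, 1, ...) inside one dense block\n",
    "    return [in_channels + i * growth_rate for i in range(num_layers)]\n",
    "\n",
    "def dense_block_out_channels(in_channels, num_layers, growth_rate):\n",
    "    # Each of the num_layers BottleNeck layers concatenates growth_rate new channels\n",
    "    return in_channels + num_layers * growth_rate\n",
    "\n",
    "print(bottleneck_input_channels(3, 2, 10))   # [3, 13]\n",
    "print(dense_block_out_channels(3, 2, 10))    # 23, the last dimension of Y above\n",
    "print(dense_block_out_channels(64, 4, 32))   # 192"
   ]
  },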
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
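    "## 5.12.2 Transition Layers\n",
    "\n",
    "Since every dense block increases the number of channels, using too many of them leads to an overly complex model. Transition layers are used to control model complexity. A transition layer reduces the number of channels with a $1\\times1$ convolutional layer and halves the height and width with an average pooling layer of stride 2, which further reduces model complexity."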
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TransitionLayer(tf.keras.layers.Layer):\n",
    "    def __init__(self, out_channels):\n",
    "        super(TransitionLayer, self).__init__()\n",
    "        self.bn = tf.keras.layers.BatchNormalization()\n",
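    "        # 1x1 convolution reduces the number of channels to out_channels\n",
    "        self.conv = tf.keras.layers.Conv2D(filters=out_channels,\n",
    "                                           kernel_size=(1, 1),\n",
    "                                           strides=1,\n",
    "                                           padding=\"same\")\n",
    "        # 2x2 average pooling with stride 2 halves the height and width\n",
    "        self.pool = tf.keras.layers.AveragePooling2D(pool_size=(2, 2),\n",
    "                                                     strides=2,\n",
    "                                                     padding=\"same\")\n",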
    "\n",
    "    def call(self, inputs):\n",
    "        x = self.bn(inputs)\n",
    "        x = tf.keras.activations.relu(x)\n",
    "        x = self.conv(x)\n",
    "        x = self.pool(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
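    "We apply a transition layer with 10 channels to the output of the dense block from the previous example. The number of output channels is reduced to 10, and the height and width are halved."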
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "TensorShape([4, 4, 4, 10])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "blk = TransitionLayer(10)\n",
    "blk(Y).shape"
   ]
  },
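  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `padding=\"same\"`, a stride-2 pooling layer outputs $\\lceil H/2 \\rceil$ rows rather than $\\lfloor H/2 \\rfloor$. This matters whenever the height or width is odd, for example when $28\\times 28$ inputs have shrunk to a $7\\times 7$ feature map. The minimal sketch below checks this on an arbitrary $7\\times 7\\times 8$ input using `AveragePooling2D`. The same output-size rule applies to `MaxPool2D`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "x = tf.random.uniform((1, 7, 7, 8))  # arbitrary odd-sized feature map\n",
    "same_pool = tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=2, padding='same')\n",
    "valid_pool = tf.keras.layers.AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid')\n",
    "\n",
    "print(same_pool(x).shape)   # (1, 4, 4, 8): ceil(7 / 2) = 4\n",
    "print(valid_pool(x).shape)  # (1, 3, 3, 8): floor((7 - 2) / 2) + 1 = 3"
   ]
  },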
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
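    "## 5.12.3 DenseNet Model\n",
    "\n",
    "Now we construct the DenseNet model. DenseNet starts with the same single convolutional layer and max pooling layer as ResNet. Where ResNet then uses four residual blocks, DenseNet uses four dense blocks. As with ResNet, we can set the number of convolutional layers in each dense block. Here we set it to 4, consistent with the ResNet-18 in the previous section. The number of channels of the convolutional layers in a dense block (that is, the growth rate) is set to 32, so each dense block adds $4\\times 32=128$ channels.\n",
    "\n",
    "ResNet reduces the height and width between modules with residual blocks of stride 2. Here, we instead use transition layers to halve the height and width and to halve the number of channels (`compression_rate=0.5`)."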
   ]
  },
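  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before writing the model, it can help to trace the channel and spatial bookkeeping by hand. The pure-Python sketch below assumes the same stem (a stride-2 $7\\times 7$ convolution followed by a stride-2 max pool, both with `\"same\"` padding), dense blocks that add `growth_rate * num_layers` channels, and transition layers that multiply the channel count by `compression_rate` and halve the height and width. `trace_densenet` is a hypothetical helper, not part of the model. For a $96\\times 96$ input it should reproduce the channel counts 192, 96, 224, 112, 240, 120, 248 printed by the shape check further below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "def trace_densenet(input_size, num_init_features=64, growth_rate=32,\n",
    "                   block_layers=(4, 4, 4, 4), compression_rate=0.5):\n",
    "    # Stem: 7x7 conv (stride 2) then 3x3 max pool (stride 2), both with 'same' padding\n",
    "    size = math.ceil(math.ceil(input_size / 2) / 2)\n",
    "    channels = num_init_features\n",
    "    trace = []\n",
    "    for i, num_layers in enumerate(block_layers):\n",
    "        channels += growth_rate * num_layers           # dense block adds channels\n",
    "        trace.append(('dense_block', size, channels))\n",
    "        if i != len(block_layers) - 1:                 # no transition after the last block\n",
    "            channels = int(compression_rate * channels)\n",
    "            size = math.ceil(size / 2)\n",
    "            trace.append(('transition', size, channels))\n",
    "    return trace\n",
    "\n",
    "for stage in trace_densenet(96):\n",
    "    print(stage)"
   ]
  },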
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DenseNet(tf.keras.Model):\n",
    "    def __init__(self, num_init_features, growth_rate, block_layers, compression_rate, drop_rate):\n",
    "        super(DenseNet, self).__init__()\n",
    "        self.conv = tf.keras.layers.Conv2D(filters=num_init_features,\n",
    "                                           kernel_size=(7, 7),\n",
    "                                           strides=2,\n",
    "                                           padding=\"same\")\n",
    "        self.bn = tf.keras.layers.BatchNormalization()\n",
    "        self.pool = tf.keras.layers.MaxPool2D(pool_size=(3, 3),\n",
    "                                              strides=2,\n",
    "                                              padding=\"same\")\n",
    "        self.num_channels = num_init_features\n",
    "        self.dense_block_1 = DenseBlock(num_layers=block_layers[0], growth_rate=growth_rate, drop_rate=drop_rate)\n",
    "        self.num_channels += growth_rate * block_layers[0]\n",
    "        self.num_channels = compression_rate * self.num_channels\n",
    "        self.transition_1 = TransitionLayer(out_channels=int(self.num_channels))\n",
    "        self.dense_block_2 = DenseBlock(num_layers=block_layers[1], growth_rate=growth_rate, drop_rate=drop_rate)\n",
    "        self.num_channels += growth_rate * block_layers[1]\n",
    "        self.num_channels = compression_rate * self.num_channels\n",
    "        self.transition_2 = TransitionLayer(out_channels=int(self.num_channels))\n",
    "        self.dense_block_3 = DenseBlock(num_layers=block_layers[2], growth_rate=growth_rate, drop_rate=drop_rate)\n",
    "        self.num_channels += growth_rate * block_layers[2]\n",
    "        self.num_channels = compression_rate * self.num_channels\n",
    "        self.transition_3 = TransitionLayer(out_channels=int(self.num_channels))\n",
    "        self.dense_block_4 = DenseBlock(num_layers=block_layers[3], growth_rate=growth_rate, drop_rate=drop_rate)\n",
    "\n",
    "        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()\n",
    "        self.fc = tf.keras.layers.Dense(units=10,\n",
    "                                        activation=tf.keras.activations.softmax)\n",
    "\n",
    "    def call(self, inputs):\n",
    "        x = self.conv(inputs)\n",
    "        x = self.bn(x)\n",
    "        x = tf.keras.activations.relu(x)\n",
    "        x = self.pool(x)\n",
    "\n",
    "        x = self.dense_block_1(x)\n",
    "        x = self.transition_1(x)\n",
    "        x = self.dense_block_2(x)\n",
    "        x = self.transition_2(x)\n",
    "        x = self.dense_block_3(x)\n",
    "        x = self.transition_3(x)\n",
    "        x = self.dense_block_4(x)\n",
    "\n",
    "        x = self.avgpool(x)\n",
    "        x = self.fc(x)\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "def densenet():\n",
    "    return DenseNet(num_init_features=64, growth_rate=32, block_layers=[4, 4, 4, 4], compression_rate=0.5, drop_rate=0.5)\n",
    "\n",
    "mynet = densenet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To check that the network is correct, we print the output shape of each sub-module:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv2d_45 output shape:\t (1, 48, 48, 64)\n",
      "batch_normalization_45 output shape:\t (1, 48, 48, 64)\n",
      "max_pooling2d_5 output shape:\t (1, 24, 24, 64)\n",
      "dense_block_6 output shape:\t (1, 24, 24, 192)\n",
      "transition_layer_4 output shape:\t (1, 12, 12, 96)\n",
      "dense_block_7 output shape:\t (1, 12, 12, 224)\n",
      "transition_layer_5 output shape:\t (1, 6, 6, 112)\n",
      "dense_block_8 output shape:\t (1, 6, 6, 240)\n",
      "transition_layer_6 output shape:\t (1, 3, 3, 120)\n",
      "dense_block_9 output shape:\t (1, 3, 3, 248)\n",
      "global_average_pooling2d_1 output shape:\t (1, 248)\n",
      "dense output shape:\t (1, 10)\n"
     ]
    }
   ],
   "source": [
    "X = tf.random.uniform(shape=(1, 96, 96, 1))\n",
    "for layer in mynet.layers:\n",
    "    X = layer(X)\n",
    "    print(layer.name, 'output shape:\\t', X.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
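    "## 5.12.4 Obtaining the Data and Training the Model\n",
    "\n",
    "Because this network is fairly deep, we keep the computation cheap. The shape check above used $96\\times 96$ inputs, but the training code below feeds the $28\\times 28$ Fashion-MNIST images directly, without resizing them. This works because every downsampling step uses `\"same\"` padding, so the height and width shrink as $28\\to 14\\to 7\\to 4\\to 2\\to 1$."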
   ]
  },
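  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The training cell below passes the $28\\times 28$ arrays to `fit` as they are. To train at $96\\times 96$ instead, one option is to resize lazily in a `tf.data` pipeline rather than holding every enlarged image in memory at once. This is a minimal sketch, and `make_resized_dataset` is a hypothetical helper. Note that `Model.fit` does not support `validation_split` for `tf.data.Dataset` inputs, so a validation set would have to be split off manually. Training would then look roughly like `mynet.fit(make_resized_dataset(x_train, y_train, shuffle=True), epochs=5)`, at a noticeably higher cost per epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "def make_resized_dataset(images, labels, size=96, batch_size=64, shuffle=False):\n",
    "    # images: float32 array of shape (N, 28, 28, 1) scaled to [0, 1]\n",
    "    ds = tf.data.Dataset.from_tensor_slices((images, labels))\n",
    "    if shuffle:\n",
    "        ds = ds.shuffle(buffer_size=10000)\n",
    "    # Resize each image on the fly instead of materializing all resized images\n",
    "    ds = ds.map(lambda x, y: (tf.image.resize(x, [size, size]), y))\n",
    "    return ds.batch(batch_size)\n",
    "\n",
    "# Small random stand-in for Fashion-MNIST, only to check the shapes\n",
    "dummy_x = np.random.rand(10, 28, 28, 1).astype('float32')\n",
    "dummy_y = np.random.randint(0, 10, size=10)\n",
    "batch_x, batch_y = next(iter(make_resized_dataset(dummy_x, dummy_y, batch_size=4)))\n",
    "print(batch_x.shape, batch_y.shape)  # (4, 96, 96, 1) (4,)"
   ]
  },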
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 48000 samples, validate on 12000 samples\n",
      "Epoch 1/5\n",
      "48000/48000 [==============================] - 244s 5ms/sample - loss: 0.6859 - accuracy: 0.7465 - val_loss: 0.4778 - val_accuracy: 0.8270\n",
      "Epoch 2/5\n",
      "48000/48000 [==============================] - 267s 6ms/sample - loss: 0.3933 - accuracy: 0.8541 - val_loss: 0.3478 - val_accuracy: 0.8716\n",
      "Epoch 3/5\n",
      "48000/48000 [==============================] - 263s 5ms/sample - loss: 0.3312 - accuracy: 0.8783 - val_loss: 0.3403 - val_accuracy: 0.8720\n",
      "Epoch 4/5\n",
      "48000/48000 [==============================] - 240s 5ms/sample - loss: 0.3013 - accuracy: 0.8888 - val_loss: 0.3079 - val_accuracy: 0.8842\n",
      "Epoch 5/5\n",
      "48000/48000 [==============================] - 241s 5ms/sample - loss: 0.2783 - accuracy: 0.8974 - val_loss: 0.2962 - val_accuracy: 0.8913\n",
      "10000/1 - 11s - loss: 0.2877 - accuracy: 0.8848\n"
     ]
    }
   ],
   "source": [
    "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()\n",
    "x_train = x_train.reshape((60000, 28, 28, 1)).astype('float32') / 255\n",
    "x_test = x_test.reshape((10000, 28, 28, 1)).astype('float32') / 255\n",
    "\n",
    "mynet.compile(loss='sparse_categorical_crossentropy',\n",
    "              optimizer=tf.keras.optimizers.Adam(),\n",
    "              metrics=['accuracy'])\n",
    "\n",
    "history = mynet.fit(x_train, y_train,\n",
    "                    batch_size=64,\n",
    "                    epochs=5,\n",
    "                    validation_split=0.2)\n",
    "test_scores = mynet.evaluate(x_test, y_test, verbose=2)"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": true,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "position": {
    "height": "525px",
    "left": "923px",
    "right": "20px",
    "top": "127px",
    "width": "353px"
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
